from abc import ABC, abstractmethod
from typing import Tuple, Optional, Callable

import torch
import torch.nn as nn
from tensordict import TensorDict
from torch import Tensor

from tabasco.flow.path import FlowPath
from tabasco.utils.metric_utils import split_losses_by_time
from tabasco.utils.tensor_ops import mask_and_zero_com, apply_mask

import torch.nn.functional as F


class Interpolant(ABC):
    """Abstract base class for data–noise interpolation.

    Subclasses must implement four domain-specific operations:
    1. sample_noise:    draw a noise tensor matching the data layout;
    2. create_path:     build the interpolation path between two data points for a given time t;
    3. compute_loss:    return a supervised loss for a model prediction along the path;
    4. step:            advance the system one integration step during sampling.

    All methods work on batched `TensorDict` objects; the data entry is accessed via
    `key` and its padding mask via `key_pad_mask`.
    """

    def __init__(
        self,
        key: str,
        key_pad_mask: str = "padding_mask",
        loss_weight: float = 1.0,
        time_factor: Optional[Callable] = None,
    ):
        """Initialize the interpolant.

        Args:
            key: Key of the data entry of interest in the batch TensorDict.
            key_pad_mask: Key of the padding mask in the batch TensorDict.
            loss_weight: Multiplier for the scalar loss that subclasses return from
                `compute_loss`.
            time_factor: Optional callable intended to map interpolation times to
                per-sample loss weights. `None` disables time weighting.
        """
        self.key = key
        self.key_pad_mask = key_pad_mask
        self.time_factor = time_factor
        self.loss_weight = loss_weight

    @abstractmethod
    def sample_noise(self, shape: torch.Size, pad_mask: Tensor) -> Tensor:
        """Draw a random noise tensor compatible with the data layout.

        Args:
            shape: Desired tensor shape, usually `batch[self.key].shape`.
            pad_mask: Boolean/int mask where 1 indicates padded positions; noise must be
                zeroed at these indices.

        Return:
            Tensor: Noise tensor of shape `shape` located on the same device as
                `pad_mask`.
        """
        ...

    @abstractmethod
    def create_path(
        self, x_1: TensorDict, t: Tensor, x_0: Optional[TensorDict] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Construct the interpolation triple `(x_0, x_t, dx_t)` for time `t`.

        Args:
            x_1: TensorDict containing the reference data point at *t = 1*.
            t: Tensor of shape `(B,)` with interpolation times in `[0, 1]`.
            x_0: Optional TensorDict with a pre-sampled noise state; if `None` a new
                one is drawn via `sample_noise`.

        Return:
            Tuple[Tensor, Tensor, Tensor]:
                * x_0: initial noise state,
                * x_t: interpolated state at time `t`,
                * dx_t: velocity, typically `x_1 - x_0`.
        """
        ...

    @abstractmethod
    def compute_loss(
        self, path: FlowPath, pred: TensorDict, compute_stats: bool = True
    ) -> Tuple[Tensor, dict]:
        """Return a supervised loss for a model prediction at time `t`.

        Args:
            path: FlowPath object built from the output of `create_path`.
            pred: TensorDict with model outputs that correspond to `path.x_1[self.key]`.
            compute_stats: If `True`, also compute and return auxiliary metrics.

        Return:
            Tuple[Tensor, dict]: Scalar loss and a (possibly empty) statistics
            dictionary.
        """
        ...

    @abstractmethod
    def step(
        self, batch_t: TensorDict, pred: TensorDict, t: Tensor, dt: float
    ) -> Tensor:
        """Advance the sample by one integration step of the sampling process.

        Args:
            batch_t: TensorDict with the current sample at time `t`.
            pred: Model prediction (same layout as `batch_t`) used to compute the
                velocity field.
            t: Tensor of shape `(B,)` with the current times.
            dt: Scalar or tensor step size applied to each batch element.

        Return:
            Tensor: Updated data tensor corresponding to time `t + dt`.
        """
        ...


class DiscreteInterpolant(Interpolant):
    """Interpolates between two discrete (categorical) distributions.

    Noise and sampled states are one-hot tensors of shape `(…, C)`. Along the
    path, each position independently shows the data category with probability
    `t` and the noise category otherwise. The loss takes `argmax` of `x_1`, so
    `x_1` is presumably one-hot as well (TODO confirm).
    """

    def __init__(self, **kwargs):
        """Initialize the discrete interpolant.

        Args:
            **kwargs: Forwarded to `Interpolant.__init__`.
        """
        super().__init__(**kwargs)
        self.ce_loss = nn.CrossEntropyLoss(reduction="none")

    def sample_noise(self, shape: torch.Size, pad_mask: Tensor) -> Tensor:
        """Return uniformly random one-hot noise with padded rows zeroed.

        Args:
            shape: Desired output shape `(…, C)` where `C` equals the number of
                discrete categories.
            pad_mask: Padding mask broadcastable to `shape[:-1]`. Positions set to
                1/True are padding and become all-zero rows.

        Returns:
            Tensor: Integer one-hot noise tensor on the same device as `pad_mask`.
        """
        num_classes = shape[-1]
        categories = torch.randint(0, num_classes, shape[:-1]).to(pad_mask.device)
        x_0 = F.one_hot(categories, num_classes=num_classes)
        # The base-class contract requires the noise to be zero at padded positions.
        real_mask = (1 - pad_mask.int()).unsqueeze(-1)
        return x_0 * real_mask

    def create_path(
        self, x_1: TensorDict, t: Tensor, x_0: Optional[TensorDict] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Create a path for a ground-truth point and a time step.

        Args:
            x_1: TensorDict holding the data under `self.key` and its padding mask
                under `self.key_pad_mask`.
            t: Tensor of shape `(B,)` with interpolation times.
            x_0: Optional TensorDict with pre-sampled noise under `self.key`. When
                `None`, noise is drawn with `sample_noise`.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: `(x_0, x_t, dx_t)`, where `x_t` mixes noise
            and data per position and `dx_t = x_1 - x_0`.
        """
        data = x_1[self.key]
        if x_0 is None:
            x_0_tensor = self.sample_noise(data.shape, x_1[self.key_pad_mask])
        else:
            x_0_tensor = x_0[self.key]

        t = t.unsqueeze(-1)
        assert t.shape == (data.shape[0], 1), (
            f"t shape: {t.shape} != {(data.shape[0], 1)}"
        )

        # Each position keeps its noise value with probability 1 - t.
        _corruption_prob = torch.rand(data.shape[:-1]).to(data.device)
        corrupt_mask = (_corruption_prob > t).unsqueeze(-1).int()

        x_t = x_0_tensor * corrupt_mask + data * (1 - corrupt_mask)
        dx_t = data - x_0_tensor

        return x_0_tensor, x_t, dx_t

    def compute_loss(
        self, path: FlowPath, pred: TensorDict, compute_stats: bool = False
    ) -> Tuple[Tensor, dict]:
        """Compute the cross-entropy loss between prediction and ground truth.

        Args:
            path: FlowPath built from `create_path`. Only `path.x_1` and `path.t`
                are read.
            pred: Model logits with shape `(B, N, C)` under `self.key`.
            compute_stats: Unused. The stats dict is always empty here.

        Returns:
            Tuple[Tensor, dict]: Weighted mean loss over molecules and an empty
            statistics dict.
        """
        real_mask = 1 - path.x_1[self.key_pad_mask].int()
        # Clamp so that a fully padded sample contributes 0 instead of NaN.
        n_atoms = real_mask.sum(dim=-1).clamp(min=1)

        # CrossEntropyLoss expects the class dimension at dim 1: (B, N, C) -> (B, C, N).
        loss = self.ce_loss(
            pred[self.key].transpose(1, 2), path.x_1[self.key].argmax(dim=-1)
        )
        per_mol_loss = (loss * real_mask).sum(dim=-1) / n_atoms

        if self.time_factor:
            per_mol_loss = per_mol_loss * self.time_factor(path.t)

        total_loss = per_mol_loss.mean() * self.loss_weight

        stats_dict = {}

        return total_loss, stats_dict

    def step(self, batch_t: TensorDict, pred: TensorDict, t: Tensor, dt: float):
        """Take a stochastic forward-Euler step for discrete states in continuous time.

        Args:
            batch_t: TensorDict containing one-hot states at time `t`.
            pred: TensorDict whose `self.key` entry holds logits predicting the
                terminal distribution.
            t: Tensor `(B,)` with the current time.
            dt: Step size to advance, either a scalar or a tensor of shape `(B,)`.

        Returns:
            Tensor: One-hot tensor representing the new discrete state.
        """
        if not torch.is_tensor(dt) or dt.ndim == 0:
            # Broadcast a scalar step size to every batch element.
            dt = torch.full_like(t, float(dt))

        t = t.unsqueeze(-1).unsqueeze(-1)
        dt = dt.unsqueeze(-1).unsqueeze(-1)
        assert dt.shape == t.shape == (batch_t[self.key].shape[0], 1, 1), (
            f"t shape: {t.shape}, dt shape: {dt.shape}, batch_t shape: {batch_t[self.key].shape}"
        )

        x1_probs = F.softmax(pred[self.key], dim=-1)
        curr_state = batch_t[self.key].argmax(dim=-1)

        # Jump probability toward each category: dt / (1 - t) * p(x_1 = j | x_t).
        # NOTE(review): this diverges at t == 1; callers presumably stop before t = 1.
        step_probs = ((dt / (1 - t)) * x1_probs).clamp(max=1.0)
        # The probability of staying equals 1 minus the total probability of leaving.
        step_probs.scatter_(-1, curr_state[:, :, None], 0.0)
        step_probs.scatter_(
            -1, curr_state[:, :, None], 1.0 - step_probs.sum(dim=-1, keepdim=True)
        )
        step_probs = step_probs.clamp(min=0.0)

        step_probs = step_probs / step_probs.sum(dim=-1, keepdim=True)

        x_next = torch.distributions.Categorical(step_probs).sample()
        x_next = F.one_hot(x_next, num_classes=batch_t[self.key].shape[-1])

        return x_next


class CenteredMetricInterpolant(Interpolant):
    """Linear interpolation between two points in Euclidean space.

    This class teaches the model to predict the endpoint of the path. The path
    endpoints and every sampling step are passed through `mask_and_zero_com`,
    which applies the padding mask and centering.
    """

    def __init__(
        self,
        centered: bool = True,
        scale_noise_by_log_num_atoms: bool = False,
        noise_scale: float = 1.0,
        **kwargs,
    ):
        """Initialize the metric interpolant.

        Args:
            centered: If True, `sample_noise` subtracts the center of mass so that
                translation is ignored. Otherwise it only masks.
            scale_noise_by_log_num_atoms: Scale noise amplitude by `log(N_atoms)`.
            noise_scale: Standard deviation of the sampled Gaussian noise.
            **kwargs: Forwarded to `Interpolant.__init__`.
        """
        super().__init__(**kwargs)
        self.mse_loss = nn.MSELoss(reduction="none")
        self.centered = centered
        self.scale_noise_by_log_num_atoms = scale_noise_by_log_num_atoms
        self.noise_scale = noise_scale

    def sample_noise(self, shape: torch.Size, pad_mask: Tensor) -> Tensor:
        """Return masked Gaussian noise with optional scaling.

        Args:
            shape: Desired output shape. The log-atom scaling indexes it as `(B, N, D)`.
            pad_mask: Padding mask (bool or int, 1/True = padded).

        Returns:
            Tensor: Noise tensor, masked and optionally centered.
        """
        x_0 = torch.randn(shape).to(pad_mask.device) * self.noise_scale

        if self.scale_noise_by_log_num_atoms:
            # Count real atoms in a way that works for both bool and int masks.
            # `~` on an int mask is a bitwise NOT, not a logical one.
            num_atoms = (1 - pad_mask.int()).sum(dim=-1)
            # NOTE(review): log(1) == 0, so single-atom samples get zero noise.
            x_0 = x_0 * torch.log(num_atoms[..., None, None])

        if self.centered:
            x_0 = mask_and_zero_com(x_0, pad_mask)
        else:
            x_0 = apply_mask(x_0, pad_mask)

        return x_0

    def create_path(
        self, x_1: TensorDict, t: Tensor, x_0: Optional[TensorDict] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Generate `(x_0, x_t, dx_t)` via linear interpolation in Euclidean space.

        Args:
            x_1: TensorDict with coordinates under `self.key` and the padding mask
                under `self.key_pad_mask`.
            t: Tensor `(B,)` with interpolation times.
            x_0: Optional TensorDict with pre-sampled noise under `self.key`.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: centered noise `x_0`,
            `x_t = (1 - t) * x_0 + t * x_1`, and `dx_t = x_1 - x_0`.
        """
        pad_mask = x_1[self.key_pad_mask]
        if x_0 is None:
            x_0_tensor = self.sample_noise(x_1[self.key].shape, pad_mask)
        else:
            x_0_tensor = x_0[self.key]

        t = t.unsqueeze(-1).unsqueeze(-1)
        assert t.shape == (x_1[self.key].shape[0], 1, 1), (
            f"t shape: {t.shape} != {(x_1[self.key].shape[0], 1, 1)}"
        )

        # NOTE(review): both endpoints are re-centered here regardless of
        # `self.centered`, while `compute_loss` compares against the raw `x_1`.
        # This presumably assumes the data are already centered. Confirm.
        x_0_tensor = mask_and_zero_com(x_0_tensor, pad_mask)
        x_1_tensor = mask_and_zero_com(x_1[self.key], pad_mask)

        x_t = (1.0 - t) * x_0_tensor + t * x_1_tensor
        dx_t = x_1_tensor - x_0_tensor

        return x_0_tensor, x_t, dx_t

    def compute_loss(
        self, path: FlowPath, pred: TensorDict, compute_stats: bool = True
    ) -> Tuple[Tensor, dict]:
        """Compute the mean-squared error on masked coordinates, with optional time weighting.

        Args:
            path: FlowPath whose `x_1` holds the target coordinates and whose `t`
                holds the interpolation times.
            pred: TensorDict with predicted endpoint coordinates under `self.key`.
            compute_stats: If True, also report the loss split into 5 time bins.

        Returns:
            Tuple[Tensor, dict]: Weighted scalar loss and a statistics dict.
        """
        real_mask = 1 - path.x_1[self.key_pad_mask].int()
        # Clamp so that a fully padded sample contributes 0 instead of NaN.
        n_atoms = real_mask.sum(dim=-1).clamp(min=1)

        err = (pred[self.key] - path.x_1[self.key]) * real_mask.unsqueeze(-1)
        # Per-sample mean over real atoms and coordinate dimensions.
        loss = torch.sum(err**2, dim=(-1, -2)) / (n_atoms * err.shape[-1])

        if self.time_factor:
            loss = loss * self.time_factor(path.t)

        if compute_stats:
            binned_losses = split_losses_by_time(path.t, loss, 5)
            stats_dict = {
                f"coords_loss_bin_{i}": bin_loss
                for i, bin_loss in enumerate(binned_losses)
            }
        else:
            stats_dict = {}

        total_loss = loss.mean() * self.loss_weight
        return total_loss, stats_dict

    def step(self, batch_t: TensorDict, pred: TensorDict, t: Tensor, dt: float):
        """Take a deterministic forward-Euler step for continuous coordinates.

        Args:
            batch_t: TensorDict with current coordinates and padding mask.
            pred: TensorDict with the predicted endpoint `x_1` under `self.key`.
            t: Tensor `(B,)` with the current time.
            dt: Step size, either a scalar or a tensor of shape `(B,)`.

        Returns:
            Tensor: Masked, centered coordinates at time `t + dt`.
        """
        if not torch.is_tensor(dt) or dt.ndim == 0:
            # Broadcast a scalar step size to every batch element.
            dt = torch.full_like(t, float(dt))

        t = t.unsqueeze(-1).unsqueeze(-1)
        dt = dt.unsqueeze(-1).unsqueeze(-1)
        assert dt.shape == t.shape == (batch_t[self.key].shape[0], 1, 1), (
            f"t shape: {t.shape}, dt shape: {dt.shape}, batch_t shape: {batch_t[self.key].shape}"
        )

        x1_pred = pred[self.key]
        # Velocity implied by the endpoint prediction. It diverges at t == 1.
        velocity = (x1_pred - batch_t[self.key]) / (1 - t)

        x_new = batch_t[self.key] + velocity * dt
        x_new = mask_and_zero_com(x_new, batch_t[self.key_pad_mask])

        assert x_new.shape == batch_t[self.key].shape, (
            f"x_new shape: {x_new.shape} != {batch_t[self.key].shape}"
        )

        return x_new


class SDEMetricInterpolant(CenteredMetricInterpolant):
    """CenteredMetricInterpolant with Langevin/SDE-style sampling based on the Proteina paper.

    `step` integrates

        dx = [v(x, t) + g(t) * s(x, t)] dt + sqrt(2 * g(t) * eta) dW

    with the Euler–Maruyama scheme. Here `g` is `langevin_sampling_schedule`,
    `s` is the score from `calculate_score`, and `eta` is
    `white_noise_sampling_scale`.
    """

    def __init__(
        self,
        langevin_sampling_schedule: Optional[Callable] = None,
        white_noise_sampling_scale: float = 1.0,
        **kwargs,
    ):
        """Initialize the SDE interpolant with Langevin sampling parameters.

        Args:
            langevin_sampling_schedule: Callable mapping time `t` to the Langevin
                weight `g(t)`. `None` means `g(t) = 0`, which makes `step` a
                deterministic Euler step.
            white_noise_sampling_scale: Scale `eta` of the injected Brownian noise.
            **kwargs: Forwarded to `CenteredMetricInterpolant.__init__`. That
                initializer already sets up `mse_loss`.
        """
        super().__init__(**kwargs)
        self.white_noise_sampling_scale = white_noise_sampling_scale

        if langevin_sampling_schedule is None:
            self.langevin_sampling_schedule = lambda t: torch.zeros_like(t)
        else:
            self.langevin_sampling_schedule = langevin_sampling_schedule

    def calculate_score(self, v_t, x_t, t):
        """Return the diffusion score `(t * v_t - x_t) / (1 - t)` as used in Proteina."""
        # The small epsilon keeps the score finite at t == 1.
        return (t * v_t - x_t) / (1 - t + 1e-6)

    def step(self, batch_t: TensorDict, pred: TensorDict, t: Tensor, dt: float):
        """Take one Euler–Maruyama step of the Langevin-corrected SDE.

        Args:
            batch_t: TensorDict with current coordinates and padding mask.
            pred: TensorDict with the predicted endpoint `x_1` under `self.key`.
            t: Tensor `(B,)` with the current time.
            dt: Step size, either a scalar or a tensor of shape `(B,)`.

        Returns:
            Tensor: Masked, centered coordinates at time `t + dt`.
        """
        if not torch.is_tensor(dt) or dt.ndim == 0:
            # Broadcast a scalar step size to every batch element.
            dt = torch.full_like(t, float(dt))

        t = t.unsqueeze(-1).unsqueeze(-1)
        dt = dt.unsqueeze(-1).unsqueeze(-1)
        assert dt.shape == t.shape == (batch_t[self.key].shape[0], 1, 1), (
            f"t shape: {t.shape}, dt shape: {dt.shape}, batch_t shape: {batch_t[self.key].shape}"
        )

        x_t = batch_t[self.key]
        velocity = (pred[self.key] - x_t) / (1 - t)
        score = self.calculate_score(velocity, x_t, t)
        g_t = self.langevin_sampling_schedule(t)

        drift = velocity + g_t * score
        # Brownian increment dW ~ N(0, dt): a single standard-normal draw scaled by
        # sqrt(dt), not dt. Scaling by dt would make the injected variance shrink
        # as dt**2 and vanish under step refinement.
        diffusion = torch.sqrt(
            2 * g_t * self.white_noise_sampling_scale * dt
        ) * torch.randn_like(x_t)

        x_new = x_t + drift * dt + diffusion
        # Re-apply the padding mask and remove center-of-mass drift added by the noise.
        x_new = mask_and_zero_com(x_new, batch_t[self.key_pad_mask])

        assert x_new.shape == batch_t[self.key].shape, (
            f"x_new shape: {x_new.shape} != {batch_t[self.key].shape}"
        )

        return x_new
